# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of 3D Residual Networks."""
from typing import Any, Callable, List, Optional, Tuple

import tensorflow as tf, tf_keras

from official.modeling import hyperparams
from official.modeling import tf_utils
from official.projects.const_cl.modeling.backbones import nn_blocks_3d
from official.vision.modeling.backbones import factory
from official.vision.modeling.backbones import resnet_3d
from official.vision.modeling.layers import nn_layers

# Block-group specifications shared with the standard ResNet-3D backbone,
# indexed by model depth. Each entry is read below as
# (block type, filters, block repeats).
RESNET_SPECS = resnet_3d.RESNET_SPECS

# Short alias for the Keras layers namespace used throughout this module.
layers = tf_keras.layers


@tf_keras.utils.register_keras_serializable(package='Vision')
class ResNet3DY(tf_keras.Model):
  """Creates a 3D ResNet family model with a branched res5 block.

  Besides the regular sequence of block groups (res2, res3, ...), a second
  copy of the last (res5) block group is built. In `call` that copy consumes
  the res4 features, so the model returns two parallel res5 feature maps:
  endpoint `'5'` from the main trunk and endpoint `'5_1'` from the branch.
  """

  def __init__(
      self,
      model_id: int,
      temporal_strides: List[int],
      temporal_kernel_sizes: List[Tuple[int, ...]],
      use_self_gating: Optional[List[bool]] = None,
      input_specs: tf_keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, None, 3]),
      stem_type: str = 'v0',
      stem_conv_temporal_kernel_size: int = 5,
      stem_conv_temporal_stride: int = 2,
      stem_pool_temporal_stride: int = 2,
      init_stochastic_depth_rate: float = 0.0,
      activation: str = 'relu',
      se_ratio: Optional[float] = None,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a 3D ResNet model.

    Args:
      model_id: An `int` of depth of ResNet backbone model; used as the key
        into `RESNET_SPECS`.
      temporal_strides: A list of integers, one per block group, giving the
        temporal stride of the first block in each group.
      temporal_kernel_sizes: A list of tuples, one per block group; each tuple
        holds one temporal kernel size per block in that group.
      use_self_gating: A list of booleans, one per block group, specifying
        whether the last block of the group applies the self-gating module.
        If None (or empty), self-gating is not applied.
      input_specs: A `tf_keras.layers.InputSpec` of the input tensor.
      stem_type: A `str` of stem type of ResNet. Only `v0` is currently
        supported; any other value raises a `ValueError`.
      stem_conv_temporal_kernel_size: An `int` of temporal kernel size for the
        first conv layer.
      stem_conv_temporal_stride: An `int` of temporal stride for the first conv
        layer.
      stem_pool_temporal_stride: An `int` of temporal stride for the first pool
        layer.
      init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
      activation: A `str` of name of the activation function.
      se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A str for kernel initializer of convolutional layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for the
        convolutional layers. Default to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for the
        convolutional layers. Default to None.
      **kwargs: Additional keyword arguments passed to `tf_keras.Model`.

    Raises:
      ValueError: If the per-group specs do not match the number of block
        groups of `model_id`, or if `stem_type` or a block type is not
        supported.
    """
    super().__init__(**kwargs)

    self._model_id = model_id
    self._temporal_strides = temporal_strides
    self._temporal_kernel_sizes = temporal_kernel_sizes
    self._input_specs = input_specs
    self._stem_type = stem_type
    self._stem_conv_temporal_kernel_size = stem_conv_temporal_kernel_size
    self._stem_conv_temporal_stride = stem_conv_temporal_stride
    self._stem_pool_temporal_stride = stem_pool_temporal_stride
    self._use_self_gating = use_self_gating
    self._se_ratio = se_ratio
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = layers.experimental.SyncBatchNormalization
    else:
      self._norm = layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    # Normalize over the channel axis, wherever the image data format puts it.
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1

    # Build ResNet3D backbone.
    inputs = tf_keras.Input(shape=input_specs.shape[1:])
    self._build_model(inputs)

  def _build_model(self, inputs):
    """Builds all sublayers of the model.

    The layers are stored on `self` (`_stem_conv`, `_stem_bn`,
    `_stem_activation`, `_max_pool`, `_blocks`, `_res_5_1_layers`) and are
    wired together in `call`; nothing is returned.

    Args:
      inputs: A Keras input tensor. It is forwarded to the stem and block-group
        builders, which currently discard it.

    Raises:
      ValueError: If the temporal specs or `use_self_gating` do not have one
        entry per block group, or if a block type is not supported.
    """
    # Build stem.
    self._build_stem(inputs, stem_type=self._stem_type)

    # Use a temporal pool window of 1 when the pool does not stride in time.
    temporal_kernel_size = 1 if self._stem_pool_temporal_stride == 1 else 3
    self._max_pool = layers.MaxPool3D(
        pool_size=[temporal_kernel_size, 3, 3],
        strides=[self._stem_pool_temporal_stride, 2, 2],
        padding='same')

    # Validate per-group specs before building any block.
    resnet_specs = RESNET_SPECS[self._model_id]
    if len(self._temporal_strides) != len(resnet_specs) or len(
        self._temporal_kernel_sizes) != len(resnet_specs):
      raise ValueError(
          'Number of blocks in temporal specs should equal to resnet_specs.')
    if self._use_self_gating and len(self._use_self_gating) != len(
        resnet_specs):
      raise ValueError(
          'Number of elements in use_self_gating should equal to '
          'resnet_specs.')

    # Build intermediate blocks and endpoints.
    self._blocks = {}
    for i in range(len(resnet_specs)):
      self._blocks[f'res_{i + 2}'] = self._build_block_group_from_spec(
          inputs, resnet_specs, i, name=f'res_{i + 2}')

    # Duplicate the last (res5) block group as a separate branch. `call`
    # applies it to the res4 features and exposes it as endpoint '5_1'.
    last = len(resnet_specs) - 1
    self._res_5_1_layers = self._build_block_group_from_spec(
        inputs, resnet_specs, last, name=f'res_{last + 2}_1')

  def _build_block_group_from_spec(self, inputs, resnet_specs, index, name):
    """Builds the block group described by `resnet_specs[index]`.

    Shared by the main trunk and the duplicated res5 branch so both resolve
    the block type, self-gating flag, strides and stochastic depth the same
    way.

    Args:
      inputs: Forwarded to `_build_block_group` (which discards it).
      resnet_specs: The block-group specs for the current `model_id`.
      index: Index of the block group within `resnet_specs`.
      name: Name passed through to `_build_block_group`.

    Returns:
      The dict of block layers returned by `_build_block_group`.

    Raises:
      ValueError: If the block type of the spec is not supported.
    """
    resnet_spec = resnet_specs[index]
    if resnet_spec[0] == 'bottleneck3d':
      block_fn = nn_blocks_3d.BottleneckBlock3D
    else:
      raise ValueError('Block fn `{}` is not supported.'.format(
          resnet_spec[0]))

    use_self_gating = (
        self._use_self_gating[index] if self._use_self_gating else False)
    return self._build_block_group(
        inputs=inputs,
        filters=resnet_spec[1],
        temporal_kernel_sizes=self._temporal_kernel_sizes[index],
        temporal_strides=self._temporal_strides[index],
        # Only the first group keeps the spatial resolution; later groups
        # downsample by 2.
        spatial_strides=(1 if index == 0 else 2),
        block_fn=block_fn,
        block_repeats=resnet_spec[2],
        stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
            self._init_stochastic_depth_rate, index + 2, 5),
        use_self_gating=use_self_gating,
        name=name)

  def _build_stem(self, inputs, stem_type):
    """Builds the stem: a strided Conv3D, batch norm and activation.

    Args:
      inputs: Unused; the argument is discarded.
      stem_type: Stem variant. Only `'v0'` is supported.

    Raises:
      ValueError: If `stem_type` is not `'v0'`.
    """
    del inputs
    # Build stem.
    if stem_type == 'v0':
      self._stem_conv = layers.Conv3D(
          filters=64,
          kernel_size=[self._stem_conv_temporal_kernel_size, 7, 7],
          strides=[self._stem_conv_temporal_stride, 2, 2],
          use_bias=False,
          padding='same',
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          name='stem')
      self._stem_bn = self._norm(
          axis=self._bn_axis,
          momentum=self._norm_momentum,
          epsilon=self._norm_epsilon,
          name='stem/batch_norm')
      self._stem_activation = tf_utils.get_activation(self._activation)
    else:
      raise ValueError(f'Stem type {stem_type} not supported.')

  def _build_block_group(
      self,
      inputs: tf.Tensor,
      filters: int,
      temporal_kernel_sizes: Tuple[int, ...],
      temporal_strides: int,
      spatial_strides: int,
      block_fn: Callable[
          ..., tf_keras.layers.Layer] = nn_blocks_3d.BottleneckBlock3D,
      block_repeats: int = 1,
      stochastic_depth_drop_rate: float = 0.0,
      use_self_gating: bool = False,
      name: str = 'block_group'):
    """Creates one group of blocks for the ResNet3D model.

    Args:
      inputs: Unused; the argument is discarded.
      filters: An `int` of number of filters for the first convolution of the
        layer.
      temporal_kernel_sizes: A tuple with one temporal kernel size per block
        in the group.
      temporal_strides: An `int` temporal stride for the first block of the
        group; later blocks use stride 1.
      spatial_strides: An `int` spatial stride for the first block of the
        group. If greater than 1, the group downsamples its input; later
        blocks use stride 1.
      block_fn: The block class to instantiate, e.g.
        `nn_blocks_3d.BottleneckBlock3D`.
      block_repeats: An `int` of number of blocks contained in the group.
      stochastic_depth_drop_rate: A `float` of drop rate of the current block
        group.
      use_self_gating: A `bool` that specifies whether the last block of the
        group applies the self-gating module; earlier blocks never do.
      name: Currently ignored; the blocks are always named `cell_{i}`.

    Returns:
      A dict mapping `'cell_{i}'` to the i-th block layer, in the order the
      blocks are applied.

    Raises:
      ValueError: If `len(temporal_kernel_sizes) != block_repeats`.
    """
    # NOTE(review): the caller-supplied `name` (e.g. 'res_3') is not used;
    # every group names its blocks 'cell_0', 'cell_1', ... Renaming would
    # change layer/variable names, which existing checkpoints may rely on --
    # confirm before altering.
    del inputs, name
    if len(temporal_kernel_sizes) != block_repeats:
      raise ValueError(
          'Number of elements in temporal_kernel_sizes must equal to '
          'block_repeats.')

    cell_prefix = 'cell'
    block_layers = {}
    for i in range(block_repeats):
      is_first = i == 0
      is_last = i == block_repeats - 1
      block_layers[f'{cell_prefix}_{i}'] = block_fn(
          filters=filters,
          temporal_kernel_size=temporal_kernel_sizes[i],
          # Only the first block of the group applies the strides.
          temporal_strides=temporal_strides if is_first else 1,
          spatial_strides=spatial_strides if is_first else 1,
          stochastic_depth_drop_rate=stochastic_depth_drop_rate,
          # Only apply self-gating module in the last block.
          use_self_gating=use_self_gating if is_last else False,
          se_ratio=self._se_ratio,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon,
          name=f'{cell_prefix}_{i}')

    return block_layers

  def call(self, inputs: tf.Tensor, training: bool = False, mask: Any = None):
    """Calls ResNet3DY model.

    Args:
      inputs: The input tensor, expected to match `input_specs`.
      training: Whether the layers run in training mode.
      mask: Unused.

    Returns:
      A dict of endpoint features. Keys `'2'`, `'3'`, ... (one per block
      group) hold the outputs of the main trunk, and `'5_1'` holds the output
      of the duplicated res5 branch applied to the `'4'` features.
    """
    del mask
    x = self._stem_conv(inputs, training=training)
    x = self._stem_bn(x, training=training)
    x = self._stem_activation(x)
    x = self._max_pool(x)

    endpoints = {}
    for i, block_layers in enumerate(self._blocks.values()):
      for block_fn in block_layers.values():
        x = block_fn(x, training=training)
      endpoints[f'{i + 2}'] = x

    # The branched res5 copy starts from the res4 features, not from the
    # main-trunk res5 output.
    branch = endpoints['4']
    for block_fn in self._res_5_1_layers.values():
      branch = block_fn(branch, training=training)
    endpoints['5_1'] = branch
    return endpoints

  def get_config(self):
    """Returns the constructor arguments needed to re-create this model."""
    config_dict = {
        'model_id': self._model_id,
        'temporal_strides': self._temporal_strides,
        'temporal_kernel_sizes': self._temporal_kernel_sizes,
        'stem_type': self._stem_type,
        'stem_conv_temporal_kernel_size': self._stem_conv_temporal_kernel_size,
        'stem_conv_temporal_stride': self._stem_conv_temporal_stride,
        'stem_pool_temporal_stride': self._stem_pool_temporal_stride,
        'use_self_gating': self._use_self_gating,
        'se_ratio': self._se_ratio,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
    }
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Re-creates the model from a `get_config` dict."""
    return cls(**config)


@factory.register_backbone_builder('resnet_3dy')
def build_resnet3dy(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds a `ResNet3DY` backbone from experiment configs.

  Registered with the backbone factory under the name `'resnet_3dy'`.

  Args:
    input_specs: A `tf_keras.layers.InputSpec` describing the model input.
    backbone_config: Backbone config; `backbone_config.get()` is expected to
      expose `model_id`, `block_specs`, the stem settings,
      `stochastic_depth_drop_rate` and `se_ratio`.
    norm_activation_config: Config providing `activation`, `use_sync_bn`,
      `norm_momentum` and `norm_epsilon`.
    l2_regularizer: Optional kernel regularizer for the convolutional layers.

  Returns:
    A `ResNet3DY` instance.
  """
  cfg = backbone_config.get()
  norm_cfg = norm_activation_config

  # ResNet3DY takes one list per attribute rather than a list of per-group
  # spec objects, so transpose the block specs.
  block_specs = list(cfg.block_specs)
  per_group_temporal_strides = [spec.temporal_strides for spec in block_specs]
  per_group_kernel_sizes = [
      spec.temporal_kernel_sizes for spec in block_specs
  ]
  per_group_self_gating = [spec.use_self_gating for spec in block_specs]

  return ResNet3DY(
      model_id=cfg.model_id,
      temporal_strides=per_group_temporal_strides,
      temporal_kernel_sizes=per_group_kernel_sizes,
      use_self_gating=per_group_self_gating,
      input_specs=input_specs,
      stem_type=cfg.stem_type,
      stem_conv_temporal_kernel_size=cfg.stem_conv_temporal_kernel_size,
      stem_conv_temporal_stride=cfg.stem_conv_temporal_stride,
      stem_pool_temporal_stride=cfg.stem_pool_temporal_stride,
      init_stochastic_depth_rate=cfg.stochastic_depth_drop_rate,
      se_ratio=cfg.se_ratio,
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)
